"""

"""

from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from pydantic.types import conint
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.db.base_class import Base
from app.utils.custom_exc import CustomException
from app.utils.filters_base import Filters, OrderBy

# Type variables that parameterise the generic CRUD helpers in this module.
# ModelType: a model class bound to ``Base`` from app.db.base_class
# (presumably the SQLAlchemy declarative base — confirm).
ModelType = TypeVar("ModelType", bound=Base)
# Pydantic schema types for create / update / delete inputs.
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
DeleteSchemaType = TypeVar("DeleteSchemaType", bound=BaseModel)


class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic CRUD helpers for a SQLAlchemy model that uses soft delete.

    Read methods only return rows whose ``is_delete`` column equals ``0``;
    ``remove`` marks a row deleted by setting ``is_delete`` to ``1`` instead
    of issuing a SQL DELETE.
    """

    def __init__(self, model: Type[ModelType]):
        """
        CRUD object with default methods to Create, Read, Update, Delete (CRUD).

        **Parameters**

        * `model`: A SQLAlchemy model class. It must expose ``id`` and
          ``is_delete`` columns, which the queries in this class filter on.
        """
        self.model = model

    def get(self, db: Session, id: Any) -> Optional[ModelType]:
        """Return the live row whose ``id`` equals ``id``, or ``None`` if absent."""
        return db.query(self.model).filter(self.model.id == id, self.model.is_delete == 0).first()

    def get_multi(
        self, db: Session, *, page: int = 0, page_size: int = 100
    ) -> List[ModelType]:
        """Return one page of live rows.

        ``page`` is 1-based. Values below 1 (including the default ``0``) are
        treated as the first page, so the query never receives a negative
        OFFSET.
        """
        # Without the clamp, page=0 gave OFFSET -page_size, which databases such
        # as PostgreSQL and MySQL reject.
        offset = max(page - 1, 0) * page_size
        return (
            db.query(self.model)
            .filter(self.model.is_delete == 0)
            .offset(offset)
            .limit(page_size)
            .all()
        )

    def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
        """Insert a row built from ``obj_in``, commit, and return the refreshed instance."""
        # jsonable_encoder converts the schema to plain values usable as
        # keyword arguments for the model constructor.
        obj_in_data = jsonable_encoder(obj_in)
        db_obj = self.model(**obj_in_data)  # type: ignore
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def update(
        self,
        db: Session,
        *,
        db_obj: ModelType,
        obj_in: Union[UpdateSchemaType, Dict[str, Any]]
    ) -> ModelType:
        """Apply ``obj_in`` to ``db_obj``, commit, and return the refreshed instance.

        Only attribute names present in ``jsonable_encoder(db_obj)`` are
        considered; other keys in ``obj_in`` are ignored. For a Pydantic
        schema, only fields that were explicitly set are applied.
        """
        obj_data = jsonable_encoder(db_obj)
        if isinstance(obj_in, dict):
            update_data = obj_in
        else:
            update_data = obj_in.dict(exclude_unset=True)
        for field in obj_data:
            if field in update_data:
                setattr(db_obj, field, update_data[field])
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def remove(self, db: Session, *, id: int) -> int:
        """Soft-delete the row whose ``id`` equals ``id`` and commit.

        Sets ``is_delete`` to ``1`` with a bulk ``UPDATE`` rather than deleting
        the row. The filter does not check ``is_delete``, so an already
        soft-deleted row is matched again.

        Returns:
            The row count reported by ``Query.update``, not a model instance.
        """
        count = db.query(self.model).filter(self.model.id == id).update({self.model.is_delete: 1})
        db.commit()
        return count


class ModelCRUD(object):
    """Reusable query/CRUD helpers for one SQLAlchemy model with soft delete.

    Subclasses assign ``model``. The helpers assume the model exposes an ``id``
    column and an ``is_delete`` column where ``0`` marks a live row.
    ``delete`` soft-deletes (sets ``is_delete`` to ``1``), while ``destroy``
    issues a real SQL DELETE.
    """

    # Mapped model class the helpers operate on. Subclasses must set it, since
    # every method dereferences ``self.model``.
    model: Optional[Type[ModelType]] = None

    @staticmethod
    def _page_offset(page: int, page_size: int) -> int:
        """Return the SQL OFFSET for 1-based ``page``; pages below 1 map to the first page."""
        # Without the clamp, page <= 0 yields a negative OFFSET, which databases
        # such as PostgreSQL and MySQL reject.
        return max(page - 1, 0) * page_size

    def get_object(self, db: Session, id: Any) -> ModelType:
        """Return the live (``is_delete == 0``) row whose ``id`` equals ``id``.

        Raises:
            CustomException: if no such row exists.
        """
        instance = db.query(self.model).filter(self.model.id == id, self.model.is_delete == 0).first()  # type: ignore
        if instance is None:
            # "未找到数据" means "data not found".
            raise CustomException(err_desc="未找到数据")
        return instance

    def create(self, db: Session, *, data: CreateSchemaType) -> ModelType:
        """Insert a row built from ``data``, commit, and return the refreshed instance."""
        # jsonable_encoder converts the schema to plain values usable as
        # keyword arguments for the model constructor.
        values = jsonable_encoder(data)
        db_obj = self.model(**values)  # type: ignore
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def update(
        self,
        db: Session,
        *,
        instance: ModelType,
        data: Union[UpdateSchemaType, Dict[str, Any]],
    ) -> ModelType:
        """Apply ``data`` to ``instance``, commit, and return the refreshed instance.

        ``data`` may be a Pydantic schema or a plain dict. Only attribute names
        present in ``jsonable_encoder(instance)`` are considered; other keys are
        ignored. For a schema, only fields that were explicitly set are applied.
        """
        current = jsonable_encoder(instance)
        if isinstance(data, dict):
            changes = data
        else:
            changes = data.dict(exclude_unset=True)
        for field in current:
            if field in changes:
                setattr(instance, field, changes[field])
        db.add(instance)
        db.commit()
        db.refresh(instance)
        return instance

    def destroy(self, db: Session, *, req_body: DeleteSchemaType):
        """Hard-delete (SQL ``DELETE``) every row whose ``id`` is listed in ``req_body``.

        The body is passed through ``jsonable_encoder`` and its ``id`` entry is
        used as the id list (empty if the key is missing). Rows are deleted
        regardless of their ``is_delete`` flag.

        Returns:
            The ``Query`` used for the delete, not a row count.
        """
        # NOTE(review): presumably req_body encodes to something like
        # {"id": [1, 2, ...]}; confirm against callers.
        payload = jsonable_encoder(req_body)
        id_list = payload.get('id', [])
        query = db.query(self.model).filter(self.model.id.in_(id_list))
        # NOTE(review): Query.delete() uses the default
        # synchronize_session='evaluate'. SQLAlchemy versions before 1.4 cannot
        # evaluate IN that way, so confirm the installed version.
        query.delete()
        db.commit()
        return query

    def delete(self, db: Session, *, req_body: DeleteSchemaType):
        """Soft-delete every row whose ``id`` is listed in ``req_body``.

        Sets ``is_delete`` to ``1`` with a bulk ``UPDATE`` and commits. The id
        list is read the same way as in ``destroy``.

        Returns:
            The row count reported by ``Query.update``.
        """
        payload = jsonable_encoder(req_body)
        id_list = payload.get('id', [])
        count = db.query(self.model).filter(self.model.id.in_(id_list)).update({self.model.is_delete: 1})
        db.commit()
        return count

    @staticmethod
    def serializer(queryset, only: tuple = (), date_format: Optional[str] = None,
                   datetime_format: Optional[str] = None, time_format: Optional[str] = None):
        """Convert each instance in ``queryset`` to the value of its ``to_dict`` method.

        ``only`` and the format arguments are forwarded unchanged to ``to_dict``.
        Their meaning is defined by the model, presumably a helper on ``Base``.
        """
        return [
            instance.to_dict(only=only, date_format=date_format,
                             datetime_format=datetime_format, time_format=time_format)
            for instance in queryset
        ]

    def filter(self, queryset, req_data: Any = None, able_filter_list: Optional[list] = None):
        """Apply the request filters to ``queryset`` via ``Filters``.

        If ``able_filter_list`` is empty or None, ``queryset`` is returned
        unchanged.
        """
        if able_filter_list:
            queryset = Filters(queryset, self.model, req_data, able_filter_list).data
        return queryset

    def order_by(self, queryset, order_by_fields: Optional[tuple] = None):
        """Apply ordering to ``queryset`` via ``OrderBy``.

        If ``order_by_fields`` is empty or None, ``queryset`` is returned
        unchanged.
        """
        if order_by_fields:
            queryset = OrderBy(queryset, self.model, order_by_fields).data
        return queryset

    def get_page_queryset(self, db: Session, *, page: int = 1, page_size: conint(le=50) = 10, req_data: Any = None,
                          able_filter_list: Optional[list] = None, order_by_fields: Optional[tuple] = None):
        """Return ``(rows, total)`` for one page of filtered, ordered live rows.

        ``total`` is the number of live rows matching the filters before
        paging. ``rows`` is the list of model instances on the requested
        1-based ``page``.

        NOTE: the ``conint(le=50)`` annotation is not enforced on a plain method
        call. It only limits ``page_size`` if a caller validates it.
        """
        offset = self._page_offset(page, page_size)
        # -------------- count all matching live rows --------------
        total_query = db.query(func.count(self.model.id)).filter(self.model.is_delete == 0)
        total_query = self.filter(total_query, req_data, able_filter_list)
        total = total_query.scalar()
        # -------------- filter, order and fetch one page --------------
        query = db.query(self.model).filter(self.model.is_delete == 0)
        query = self.filter(query, req_data, able_filter_list)
        query = self.order_by(query, order_by_fields)
        rows = query.offset(offset).limit(page_size).all()
        return rows, total

    def get_queryset_pagination(self, queryset, page: int = 1, page_size: conint(le=50) = 10):
        """Return ``(rows, total)`` for one page of an existing query.

        ``total`` comes from ``queryset.count()``, before paging. ``page`` is
        1-based, and values below 1 map to the first page.
        """
        offset = self._page_offset(page, page_size)
        total = queryset.count()
        rows = queryset.offset(offset).limit(page_size).all()
        return rows, total

    def get_all_queryset(self, db: Session, *, req_data: Any = None, able_filter_list: Optional[list] = None,
                         order_by_fields: Optional[tuple] = None):
        """Return an unexecuted query of live rows with filters and ordering applied."""
        query = db.query(self.model).filter(self.model.is_delete == 0)
        query = self.filter(query, req_data, able_filter_list)
        query = self.order_by(query, order_by_fields)
        return query

    def queryset(self, db: Session):
        """Return an unexecuted query over all live (``is_delete == 0``) rows."""
        return db.query(self.model).filter(self.model.is_delete == 0)

